{
 "cells": [
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "# 模型量化\n",
    "\n",
    "\n",
    " 1. 使用 GPTQ 量化 OPT-6.7B 模型。课程代码（ https://github.com/DjangoPeng/LLM-quickstart/blob/main/transformers/AutoGPTQ_transformers.ipynb ）\n",
    "\n",
    " 2. 使用 AWQ 量化 Facebook OPT-6.7B 模型。Facebook OPT 模型地址： https://huggingface.co/facebook?search_models=opt\n",
    "\n",
    "    课程代码： https://github.com/DjangoPeng/LLM-quickstart/blob/main/transformers/AWQ_transformers.ipynb\n"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "## 1. 使用 AWQ 量化 Facebook OPT-6.7B 模型"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 3,
   "metadata": {},
   "outputs": [],
   "source": [
    "import subprocess\n",
    "import os\n",
    "\n",
    "result = subprocess.run('bash -c \"source /etc/network_turbo && env | grep proxy\"', shell=True, capture_output=True, text=True)\n",
    "output = result.stdout\n",
    "for line in output.splitlines():\n",
    "    if '=' in line:\n",
    "        var, value = line.split('=', 1)\n",
    "        os.environ[var] = value"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 1,
   "metadata": {},
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "Sat Feb 17 02:14:54 2024       \n",
      "+---------------------------------------------------------------------------------------+\n",
      "| NVIDIA-SMI 535.146.02             Driver Version: 535.146.02   CUDA Version: 12.2     |\n",
      "|-----------------------------------------+----------------------+----------------------+\n",
      "| GPU  Name                 Persistence-M | Bus-Id        Disp.A | Volatile Uncorr. ECC |\n",
      "| Fan  Temp   Perf          Pwr:Usage/Cap |         Memory-Usage | GPU-Util  Compute M. |\n",
      "|                                         |                      |               MIG M. |\n",
      "|=========================================+======================+======================|\n",
      "|   0  NVIDIA GeForce RTX 3090        On  | 00000000:41:00.0 Off |                  N/A |\n",
      "| 32%   33C    P8              21W / 350W |      0MiB / 24576MiB |      0%      Default |\n",
      "|                                         |                      |                  N/A |\n",
      "+-----------------------------------------+----------------------+----------------------+\n",
      "                                                                                         \n",
      "+---------------------------------------------------------------------------------------+\n",
      "| Processes:                                                                            |\n",
      "|  GPU   GI   CI        PID   Type   Process name                            GPU Memory |\n",
      "|        ID   ID                                                             Usage      |\n",
      "|=======================================================================================|\n",
      "|  No running processes found                                                           |\n",
      "+---------------------------------------------------------------------------------------+\n"
     ]
    }
   ],
   "source": [
    "!nvidia-smi"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 2,
   "metadata": {},
   "outputs": [
    {
     "data": {
      "application/vnd.jupyter.widget-view+json": {
       "model_id": "84e4060a84544a8daaf2056815fbf6cb",
       "version_major": 2,
       "version_minor": 0
      },
      "text/plain": [
       "Loading checkpoint shards:   0%|          | 0/2 [00:00<?, ?it/s]"
      ]
     },
     "metadata": {},
     "output_type": "display_data"
    },
    {
     "name": "stderr",
     "output_type": "stream",
     "text": [
      "/root/miniconda3/lib/python3.10/site-packages/torch/_utils.py:831: UserWarning: TypedStorage is deprecated. It will be removed in the future and UntypedStorage will be the only storage class. This should only matter to you if you are using storages directly.  To access UntypedStorage directly, use tensor.untyped_storage() instead of tensor.storage()\n",
      "  return self.fget.__get__(instance, owner)()\n",
      "Using the latest cached version of the dataset since mit-han-lab/pile-val-backup couldn't be found on the Hugging Face Hub\n",
      "Found the latest cached dataset configuration 'default' at /root/.cache/huggingface/datasets/mit-han-lab___pile-val-backup/default/0.0.0/2f5e46ae6a69cf0dce4b12f78241c408936ca0e4 (last modified on Sat Feb 17 01:58:45 2024).\n",
      "AWQ: 100%|██████████| 32/32 [13:00<00:00, 24.40s/it]\n"
     ]
    }
   ],
   "source": [
    "from awq import AutoAWQForCausalLM\n",
    "from transformers import AutoTokenizer\n",
    "from transformers import AwqConfig, AutoConfig\n",
    "\n",
    "model_path = \"/root/autodl-tmp/models/opt-6.7b\"\n",
    "quant_path = \"models/opt-6.7b-awq\"\n",
    "quant_config = {\"zero_point\": True, \"q_group_size\": 128, \"w_bit\": 4, \"version\": \"GEMM\"}\n",
    "\n",
    "\n",
    "# 加载模型\n",
    "model = AutoAWQForCausalLM.from_pretrained(model_path)\n",
    "tokenizer = AutoTokenizer.from_pretrained(model_path, trust_remote_code=True)\n",
    "\n",
    "model.quantize(tokenizer, quant_config=quant_config)\n",
    "\n",
    "\n",
    "# 修改配置文件以使其与transformers集成兼容\n",
    "quantization_config = AwqConfig(\n",
    "    bits=quant_config[\"w_bit\"],\n",
    "    group_size=quant_config[\"q_group_size\"],\n",
    "    zero_point=quant_config[\"zero_point\"],\n",
    "    version=quant_config[\"version\"].lower(),\n",
    ").to_dict()\n",
    "\n",
    "# 预训练的transformers模型存储在model属性中，我们需要传递一个字典\n",
    "model.model.config.quantization_config = quantization_config"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 3,
   "metadata": {},
   "outputs": [
    {
     "data": {
      "text/plain": [
       "('models/opt-6.7b-awq/tokenizer_config.json',\n",
       " 'models/opt-6.7b-awq/special_tokens_map.json',\n",
       " 'models/opt-6.7b-awq/vocab.json',\n",
       " 'models/opt-6.7b-awq/merges.txt',\n",
       " 'models/opt-6.7b-awq/added_tokens.json',\n",
       " 'models/opt-6.7b-awq/tokenizer.json')"
      ]
     },
     "execution_count": 3,
     "metadata": {},
     "output_type": "execute_result"
    }
   ],
   "source": [
    "# 保存模型权重\n",
    "model.save_quantized(quant_path)\n",
    "tokenizer.save_pretrained(quant_path)  # 保存分词器"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 4,
   "metadata": {},
   "outputs": [
    {
     "data": {
      "text/plain": [
       "OptAWQForCausalLM(\n",
       "  (model): OPTForCausalLM(\n",
       "    (model): OPTModel(\n",
       "      (decoder): OPTDecoder(\n",
       "        (embed_tokens): Embedding(50272, 4096, padding_idx=1)\n",
       "        (embed_positions): OPTLearnedPositionalEmbedding(2050, 4096)\n",
       "        (final_layer_norm): LayerNorm((4096,), eps=1e-05, elementwise_affine=True)\n",
       "        (layers): ModuleList(\n",
       "          (0-31): 32 x OPTDecoderLayer(\n",
       "            (self_attn): OPTAttention(\n",
       "              (k_proj): WQLinear_GEMM(in_features=4096, out_features=4096, bias=True, w_bit=4, group_size=128)\n",
       "              (v_proj): WQLinear_GEMM(in_features=4096, out_features=4096, bias=True, w_bit=4, group_size=128)\n",
       "              (q_proj): WQLinear_GEMM(in_features=4096, out_features=4096, bias=True, w_bit=4, group_size=128)\n",
       "              (out_proj): WQLinear_GEMM(in_features=4096, out_features=4096, bias=True, w_bit=4, group_size=128)\n",
       "            )\n",
       "            (activation_fn): ReLU()\n",
       "            (self_attn_layer_norm): LayerNorm((4096,), eps=1e-05, elementwise_affine=True)\n",
       "            (fc1): WQLinear_GEMM(in_features=4096, out_features=16384, bias=True, w_bit=4, group_size=128)\n",
       "            (fc2): WQLinear_GEMM(in_features=16384, out_features=4096, bias=True, w_bit=4, group_size=128)\n",
       "            (final_layer_norm): LayerNorm((4096,), eps=1e-05, elementwise_affine=True)\n",
       "          )\n",
       "        )\n",
       "      )\n",
       "    )\n",
       "    (lm_head): Linear(in_features=4096, out_features=50272, bias=False)\n",
       "  )\n",
       ")"
      ]
     },
     "execution_count": 4,
     "metadata": {},
     "output_type": "execute_result"
    }
   ],
   "source": [
    "\n",
    "model.eval()"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "## 2.使用 GPU 加载量化模型"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 5,
   "metadata": {},
   "outputs": [],
   "source": [
    "from transformers import AutoTokenizer, AutoModelForCausalLM\n",
    "\n",
    "tokenizer = AutoTokenizer.from_pretrained(quant_path)\n",
    "model = AutoModelForCausalLM.from_pretrained(quant_path, device_map=\"cuda\").to(0)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 6,
   "metadata": {},
   "outputs": [],
   "source": [
    "def generate_text(text):\n",
    "    inputs = tokenizer(text, return_tensors=\"pt\").to(0)\n",
    "\n",
    "    out = model.generate(**inputs, max_new_tokens=64)\n",
    "    return tokenizer.decode(out[0], skip_special_tokens=True)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 7,
   "metadata": {},
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "Merry Christmas! I'm glad to you you!\n",
      "Your!\n",
      " Merry Christmas! Merry Christmas! Merry Christmas!\n",
      "Merry Christmas! I'm very glad to you!\n",
      "Merry Christmas! Christmas, Merry Christmas, Merry Christmas!\n",
      "Merry Christmas! Merry Christmas! I'm very happy to you! Merry Christmas! Merry Christmas! Merry Christmas\n"
     ]
    }
   ],
   "source": [
    "result = generate_text(\"Merry Christmas! I'm glad to\")\n",
    "print(result)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 8,
   "metadata": {},
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "The woman worked as a prostitute. Who gave the woman her money?\n",
      "The prostitute? Who gave me my prostitute?\n"
     ]
    }
   ],
   "source": [
    "result = generate_text(\"The woman worked as a\")\n",
    "print(result)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": []
  }
 ],
 "metadata": {
  "kernelspec": {
   "display_name": "Python 3 (ipykernel)",
   "language": "python",
   "name": "python3"
  },
  "language_info": {
   "codemirror_mode": {
    "name": "ipython",
    "version": 3
   },
   "file_extension": ".py",
   "mimetype": "text/x-python",
   "name": "python",
   "nbconvert_exporter": "python",
   "pygments_lexer": "ipython3",
   "version": "3.10.8"
  }
 },
 "nbformat": 4,
 "nbformat_minor": 4
}
